#include <math.h>
#include <stdio.h>
#include <iostream>
#include <vector>
// #define CATCH_CONFIG_MAIN
// #include <catch.h>
#include <convolution_3x3s1dw.h>
#include <convolution_3x3s2dw.h>
#include <opencv2/opencv.hpp>
using namespace std;
using namespace cv;

// Approximate float equality used to check the convolution output against the
// expected values: true when |x - y| is strictly below an absolute tolerance
// of 0.01.
bool cmp(float x, float y){
    const float diff = x - y;
    return fabs(diff) < 0.01;
}

// Unused 3x3 values, not referenced anywhere in this file (kept for reference):
// -0.2498,  0.1771, -0.1917,
// -0.2492,  0.1465,  0.1281,
// -0.2118, -0.2419,  0.0006

// Input feature map fixture: 400 values, read by main as inw * inh * inch
// (10 * 10 * 4). Grouped as 4 blocks of 100, presumably channel-major (CHW),
// i.e. a[ch * 100 + row * 10 + col]. This is consistent with c[0] being the
// dot product of channel 0's top-left 3x3 window with the first kernel in b.
// The size is deduced from the initializer (the old [500] silently
// zero-padded 100 extra slots, hiding any mismatch with main's geometry).
static const float a[] = {
    2.8499e+00, 3.6390e+00, 2.6139e+00, 4.8630e+00, 8.1392e+00,
    8.5605e+00, 3.4631e+00, 1.4378e+00, 8.8663e+00, 7.7009e+00,
    6.1214e+00, 3.3853e+00, 2.6189e+00, 4.1047e+00, 8.5599e+00,
    5.0925e+00, 9.6310e+00, 2.8463e+00, 4.3556e+00, 1.1073e+00,
    2.0438e+00, 2.0238e+00, 9.6939e+00, 5.8516e+00, 3.2842e+00,
    6.1771e+00, 7.2959e+00, 9.0467e+00, 5.1330e+00, 1.8687e+00,
    1.3059e+00, 8.6773e-01, 8.3100e+00, 8.4838e-01, 3.6533e+00,
    1.4467e+00, 8.3336e+00, 5.6641e+00, 6.9674e+00, 5.1799e+00,
    1.9111e+00, 7.3623e+00, 5.5061e+00, 2.7012e-01, 7.4667e+00,
    2.9075e+00, 6.6789e+00, 6.4976e+00, 6.6865e+00, 4.6528e+00,
    1.8918e+00, 8.4234e+00, 2.7070e+00, 7.9141e+00, 4.9671e+00,
    6.7656e+00, 1.9369e-02, 5.7822e+00, 5.0519e+00, 6.9375e+00,
    8.6307e+00, 3.8688e+00, 8.7955e+00, 9.7277e+00, 7.5926e+00,
    6.8476e+00, 4.4097e+00, 5.1013e+00, 2.2500e+00, 5.9718e+00,
    9.2950e+00, 1.3326e-01, 1.0708e+00, 8.9223e+00, 2.8403e-01,
    6.5356e+00, 9.0030e+00, 8.4306e+00, 4.2200e+00, 1.7165e-01,
    9.5864e-02, 1.6788e+00, 6.7679e+00, 7.5707e+00, 8.9194e-01,
    8.9522e+00, 4.2658e+00, 7.0559e+00, 1.1105e+00, 3.1683e+00,
    4.0281e+00, 5.4696e+00, 2.9424e-01, 4.6289e+00, 2.1862e+00,
    9.1369e+00, 4.6429e+00, 8.5765e-02, 1.8431e+00, 3.6130e+00,

    2.5156e-01, 9.8141e+00, 6.6366e+00, 8.6019e+00, 9.2755e+00,
    2.9920e+00, 2.6030e+00, 7.0944e+00, 4.7661e-01, 5.1468e+00,
    1.5148e+00, 5.9397e+00, 8.5368e+00, 2.0255e-01, 1.5071e+00,
    4.2323e+00, 6.9482e-01, 7.6069e+00, 4.5798e-02, 3.6494e+00,
    4.0629e+00, 5.4817e+00, 1.2063e+00, 8.7509e+00, 4.6225e+00,
    8.2445e+00, 2.0703e+00, 4.9311e+00, 2.0684e+00, 3.7438e+00,
    7.5948e+00, 4.5828e+00, 3.0250e+00, 3.7534e+00, 9.0463e+00,
    6.8584e+00, 2.0255e+00, 2.9680e+00, 2.2813e+00, 4.5472e+00,
    9.7360e+00, 6.5053e+00, 1.9779e-01, 7.7927e+00, 7.7469e+00,
    2.8253e+00, 2.3396e+00, 1.3664e+00, 2.6282e-01, 2.4464e-01,
    1.9636e-01, 6.8721e-01, 3.2808e+00, 4.7743e+00, 1.5507e+00,
    5.9482e+00, 3.3815e+00, 9.5210e+00, 7.9601e+00, 4.1454e+00,
    3.4434e+00, 2.7506e+00, 1.2870e-01, 4.4299e+00, 1.8710e+00,
    3.6229e+00, 4.7827e+00, 4.3340e+00, 1.2512e+00, 3.9524e-03,
    9.4953e+00, 1.1506e+00, 7.5259e+00, 1.0886e+00, 9.0524e+00,
    3.0334e+00, 7.5643e+00, 8.9976e+00, 3.8048e+00, 1.8353e+00,
    4.7590e+00, 5.4848e+00, 3.8179e+00, 9.3669e+00, 7.2123e+00,
    8.3845e-01, 6.7436e+00, 7.3066e+00, 6.3901e+00, 5.1157e+00,
    7.1172e+00, 2.0329e+00, 9.2324e+00, 6.1155e+00, 3.3665e+00,
    6.4022e+00, 5.9759e+00, 1.5779e+00, 9.8629e+00, 9.6913e+00,

    7.3756e+00, 1.0433e-01, 3.3415e+00, 7.5005e-01, 3.6763e+00,
    4.2819e+00, 9.6236e+00, 9.4480e+00, 3.6287e+00, 1.2726e-01,
    8.4597e+00, 2.9986e+00, 7.5110e+00, 5.1281e+00, 3.4420e+00,
    7.1611e-01, 6.9129e+00, 8.0790e+00, 1.4883e+00, 2.9796e+00,
    3.8329e-01, 7.7019e+00, 1.3253e+00, 8.6503e+00, 9.1301e+00,
    2.3932e+00, 4.4542e+00, 3.4566e+00, 3.7349e+00, 1.6372e+00,
    7.8055e+00, 6.5166e+00, 9.5189e+00, 3.7081e+00, 3.3335e+00,
    4.1696e+00, 2.3822e+00, 3.2900e-01, 1.4813e+00, 5.9863e+00,
    1.0864e+00, 3.6538e+00, 8.1336e+00, 4.7854e+00, 5.8335e+00,
    7.6314e+00, 6.6008e+00, 9.0123e-01, 1.9625e-01, 9.7302e+00,
    8.0557e+00, 1.0368e+00, 8.1553e+00, 8.0094e+00, 6.1392e-01,
    4.1316e+00, 3.6023e+00, 2.9096e+00, 4.8043e+00, 2.1765e+00,
    2.6747e+00, 9.7829e+00, 6.1725e+00, 4.7258e+00, 3.3210e-01,
    4.7222e-02, 5.2130e+00, 5.1229e+00, 8.8194e+00, 7.8241e+00,
    2.9847e+00, 5.1706e+00, 7.6604e+00, 3.8066e+00, 7.5989e+00,
    2.5059e+00, 3.0046e+00, 4.9174e+00, 5.3902e+00, 5.4317e+00,
    1.0222e+00, 1.9615e-02, 2.3908e+00, 6.8977e+00, 8.0379e+00,
    3.7594e+00, 1.7204e+00, 1.5769e+00, 2.7559e+00, 6.8118e+00,
    3.0197e+00, 9.6618e+00, 1.1675e+00, 9.6470e+00, 1.4986e+00,
    1.0494e+00, 1.4214e+00, 1.7342e+00, 8.4698e-01, 9.4091e+00,

    6.2028e+00, 8.3083e+00, 8.5540e+00, 3.8282e-01, 4.9467e+00,
    6.7709e-01, 8.4318e-02, 6.7736e+00, 6.9383e+00, 2.9021e+00,
    6.4828e+00, 5.5577e+00, 9.5488e+00, 6.6526e+00, 7.6764e+00,
    3.0443e+00, 5.8916e+00, 8.6922e+00, 9.6235e+00, 6.6525e+00,
    6.3568e+00, 5.2427e+00, 3.2801e-01, 7.7384e+00, 9.6233e+00,
    4.9793e+00, 6.5436e+00, 3.4679e+00, 3.1126e+00, 9.9355e+00,
    6.0786e+00, 1.4998e-01, 1.1771e+00, 5.4363e+00, 6.5534e+00,
    1.8653e+00, 3.7479e+00, 8.7564e+00, 8.0491e+00, 4.5894e+00,
    8.2079e+00, 3.9144e+00, 6.0671e+00, 2.5066e-01, 5.3569e+00,
    3.2749e-01, 1.4315e-01, 4.4371e+00, 1.6388e-01, 1.7192e+00,
    4.0370e+00, 5.7768e+00, 1.6084e+00, 7.7434e+00, 6.3707e+00,
    5.0826e+00, 9.7197e+00, 2.2553e+00, 8.4985e+00, 7.9301e+00,
    1.2272e+00, 6.4986e+00, 1.7879e+00, 5.0290e+00, 1.0091e+00,
    5.7427e+00, 4.6768e+00, 3.2713e+00, 3.0056e+00, 9.0378e+00,
    5.4328e+00, 2.0515e+00, 4.6080e+00, 1.7228e+00, 3.3250e+00,
    7.0508e-01, 7.5573e+00, 1.5312e+00, 1.1992e+00, 4.8216e+00,
    9.6407e+00, 4.5941e+00, 6.9312e+00, 2.1987e+00, 3.6255e+00,
    1.0292e+00, 4.5049e+00, 3.5179e+00, 6.9108e+00, 3.0757e+00,
    2.7913e+00, 6.9554e+00, 1.1883e+00, 1.8356e+00, 6.9784e+00,
    7.0870e+00, 2.6204e+00, 3.1350e+00, 7.4898e+00, 2.9948e+00};

// Depthwise kernel fixture: 36 values, read by main as kw * kh * inch
// (3 * 3 * 4). Grouped as 4 blocks of 9, presumably one row-major 3x3 filter
// per channel. Confirm against convdepthwise3x3s2Neon's expected layout.
// The size is deduced from the initializer (the old [200] silently
// zero-padded 164 extra slots).
static const float b[] = {
     0.1472,  0.3213,  0.2204,
     0.2979,  0.1612, -0.1585,
    -0.0651,  0.1828, -0.3247,

    -0.0449,  0.2847,  0.2774,
     0.1953, -0.1766, -0.2018,
    -0.3108,  0.1091,  0.1150,

    -0.2269, -0.1191, -0.3120,
     0.1366, -0.1999, -0.2775,
    -0.3105, -0.3258,  0.2776,

     0.0260,  0.2830,  0.1810,
    -0.3244,  0.2250, -0.0682,
    -0.0398, -0.1173,  0.2852};

// Expected output fixture: main compares the output of convdepthwise3x3s2Neon
// against these values, element by element (tolerance 0.01). It holds 64 values,
// i.e. outw * outh * outch with outw = outh = (10 - 3) / 2 + 1 = 4 and 4
// channels, grouped as 4 blocks of 4x4 (presumably one block per channel).
// The size is deduced from the initializer (the old [400] silently
// zero-padded 336 extra slots).
static const float c[] = {
     1.2085,  3.1981,  5.1025,  5.0754,
     1.7322,  3.3310,  1.9532,  5.9851,
     2.6423,  2.5781,  5.3985,  4.5962,
     5.1760,  8.4596,  4.1937,  6.8080,

     1.6224,  7.1635,  0.2651,  0.9507,
    -0.5171,  3.5023,  1.0306,  0.7481,
    -0.0305,  4.5023,  0.0396, -3.1400,
     0.3556,  1.8906,  0.6059, -0.0768,

    -6.5173, -3.6445, -8.3167, -6.9971,
    -3.5658, -7.0114, -7.2510, -5.0279,
    -6.8944, -8.2571, -4.7053, -3.9287,
    -6.1099, -4.6488, -5.2475, -6.8962,

     1.7828,  0.9255, -0.9728,  2.7836,
     0.6343,  5.5918,  0.6711,  1.4410,
     1.9987,  1.6110, -0.7081, -1.6489,
     1.6337,  0.8185,  2.0825,  0.7815
};


int main(){
    const int inw = 10;
    const int inh = 10;
    const int inch = 4;
    const int kw = 3;
    const int kh = 3;
    int stride = 2;
    const int outw = (inw - kw) / stride + 1;
    const int outh = (inh - kh) / stride + 1;
    const int outch = 4;

    //5x5x3
    float *src = new float[inw * inh * inch];
    //3x3x4
    float *kernel = new float[kw * kh * inch];
    //3x3x4
    float *dest = new float[outw * outh * outch];

    //赋值
    for(int i = 0; i < inw * inh * inch; i++){
        src[i] = a[i];
    }

    for(int i = 0; i < kw * kh * inch; i++){
        kernel[i] = b[i];
    }
    
    int64 st = cvGetTickCount();

    // for(int i = 0; i < 10; i++){
    //     //memset(dest, 0, sizeof(dest));
    //     for(int j = 0; j < outw * outh * outch; j++) dest[j] = 0.f;
    //     convdepthwise3x3s1Neon(src, inw, inh, inch, kernel, dest, outw, outh, outch);
    // }
    convdepthwise3x3s2Neon(src, inw, inh, inch, kernel, dest, outw, outh, outch);
    
    double duration = (cv::getTickCount() - st) / cv::getTickFrequency() * 100;

    for(int i = 0; i < outw * outh * outch ; i++){
        bool flag = cmp(dest[i], c[i]);
        if(flag == false){
            printf("WA: %d\n", i);
            printf("Expected: %.4f, ConvOutput: %.4f\n", c[i], dest[i]);
        }
    }

    printf("Time: %.5f\n", duration);

    for(int i = 0; i < outw * outh * outch; i++){
        printf("%.4f ", dest[i]);
    }

    printf("\n");
    free(src);
    free(kernel);
    free(dest);

    return 0;
}